# Imports
from torch.utils.data import Dataset
from PIL import Image
import  os
"""
Read the names of all files in a folder.
"""
class MyData(Dataset):
    """Map-style dataset over the files in a single class folder.

    Samples are read from ``os.path.join(root_dir, label_dir)``, and the
    folder name ``label_dir`` itself is returned as the label, so one
    instance covers exactly one class.
    """

    def __init__(self, root_dir, label_dir):
        """Index the files in ``os.path.join(root_dir, label_dir)``.

        Args:
            root_dir: Directory containing one sub-folder per class.
            label_dir: Name of the class sub-folder; also used as the label.

        If the class folder does not exist it is created (with any missing
        parents), and the resulting dataset is empty.
        """
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        if not os.path.lexists(self.path):
            os.makedirs(self.path)
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always maps to the same file. Sub-directories are skipped
        # because they are not loadable samples.
        self.img_path = sorted(
            name
            for name in os.listdir(self.path)
            if os.path.isfile(os.path.join(self.path, name))
        )

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th file (sorted by name).

        ``image`` is the object returned by ``PIL.Image.open`` for that file;
        ``label`` is ``self.label_dir``.
        """
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        # NOTE: Pillow's Image.open reads pixel data lazily, so the file may
        # stay open until the image is loaded or closed by the caller.
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        """Return the number of files indexed in the class folder."""
        return len(self.img_path)

# Demo: build one dataset per class folder and inspect the first ant sample.
# NOTE(review): this runs at import time; consider moving it under an
# `if __name__ == "__main__":` guard if the module is ever imported.
root_dir = "../dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

# MyData creates a missing class folder instead of failing, so a dataset can
# be empty here; check before indexing [0] rather than dying on IndexError.
if len(ants_dataset) > 0:
    ant_img, ant_label = ants_dataset[0]
    print(ant_label)
    ant_img.show()
else:
    print(f"No images found in {ants_dataset.path}")

if len(bees_dataset) > 0:
    bee_img, bee_label = bees_dataset[0]
else:
    print(f"No images found in {bees_dataset.path}")

print(len(ants_dataset))